Source code for hysop.tools.field_utils

# Copyright (c) HySoP 2011-2024
#
# This file is part of HySoP software.
# See "https://particle_methods.gricad-pages.univ-grenoble-alpes.fr/hysop-doc/"
# for further info.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from hysop.tools.htypes import first_not_None, to_tuple
from hysop.tools.sympy_utils import (
    nabla,
    partial,
    subscript,
    subscripts,
    exponent,
    exponents,
    xsymbol,
    get_derivative_variables,
)

import sympy as sm
from sympy.printing.str import StrPrinter, StrReprPrinter
from sympy.printing.latex import LatexPrinter
from packaging import version

if version.parse(sm.__version__) > version.parse("1.7"):
    from sympy.printing.c import C99CodePrinter
else:
    from sympy.printing.ccode import C99CodePrinter


class BasePrinter:
    """Mixin shared by the name printers of this module.

    It is meant to be combined with a sympy printer class (placed after it in
    the bases) so that ``super()._print`` resolves to the sympy implementation.
    """

    def print_Derivative(self, expr):
        """Build the four names of a derivative expression.

        The differentiated expression is ``expr.args[0]``; the differentiation
        variables come from ``get_derivative_variables(expr)``.

        Returns whatever ``DifferentialStringFormatter.format_pd`` returns for
        those inputs; subclasses index it with 0..3 to pick name, pretty name,
        variable name and latex name respectively.
        """
        (bvar, pvar, vvar, lvar) = print_all_names(expr.args[0])
        all_xvars = get_derivative_variables(expr)
        # Deduplicate while keeping first-seen order: iterating a set gives no
        # ordering guarantee, which would make generated names unstable.
        xvars = tuple(dict.fromkeys(all_xvars))
        # Multiplicity of each distinct variable = order of derivation in it.
        varpows = tuple(all_xvars.count(x) for x in xvars)
        bxvars = tuple(print_name(x) for x in xvars)
        pxvars = tuple(print_pretty_name(x) for x in xvars)
        vxvars = tuple(print_var_name(x) for x in xvars)
        lxvars = tuple(print_latex_name(x) for x in xvars)
        return DifferentialStringFormatter.format_pd(
            bvar, pvar, vvar, lvar, bxvars, pxvars, vxvars, lxvars, varpows=varpows
        )

    def _print(self, expr, **kwds):
        """Delegate to the next ``_print`` in the MRO, reporting failures.

        On any exception a diagnostic naming the printer class and the
        expression is written to stdout, then the exception is re-raised
        unchanged.
        """
        try:
            return super()._print(expr, **kwds)
        except Exception:
            # The original used bare `print` expressions (a Python 2 idiom that
            # does nothing in Python 3); emit the intended blank lines.
            print()
            msg = "FATAL ERROR: {} failed to print expression {}."
            msg = msg.format(type(self).__name__, expr)
            print(msg)
            print()
            raise
class NamePrinter(BasePrinter, StrReprPrinter):
    """Printer returning the plain name of an expression.

    Objects exposing a ``name`` (or ``_name``) attribute are printed as that
    attribute; everything else goes through the inherited printer.
    """

    def _print(self, expr, **kwds):
        """Return ``expr.name`` / ``expr._name`` if present, else delegate."""
        if hasattr(expr, "name"):
            return expr.name
        elif hasattr(expr, "_name"):
            return expr._name
        return super()._print(expr, **kwds)

    def _print_Derivative(self, expr):
        """Return the first entry (name) of ``print_Derivative``'s result."""
        return super().print_Derivative(expr)[0]

    def _print_Add(self, expr):
        """Print a sum with every space removed."""
        return super()._print_Add(expr).replace(" ", "")

    def _print_Mul(self, expr):
        """Print a product with every space removed."""
        return super()._print_Mul(expr).replace(" ", "")

    def emptyPrinter(self, expr):
        """Reject expressions for which no ``_print_*`` method exists.

        Raises
        ------
        NotImplementedError
            Always; the message names this printer, the expression, its type
            and the MRO of that type.
        """
        # Only the fixed template goes through str.format: the printed
        # expression may itself contain '{' or '}', which would otherwise
        # break formatting and hide the intended NotImplementedError.
        msg = "\n{} does not implement _print_{}(self, expr).".format(
            self.__class__.__name__, expr.__class__.__name__
        )
        msg += f"\nExpression is {expr}."
        msg += "\nExpression type MRO is:"
        msg += "\n *" + "\n *".join(t.__name__ for t in type(expr).__mro__)
        raise NotImplementedError(msg)
class PrettyNamePrinter(BasePrinter, StrPrinter):
    """Printer returning the pretty name of an expression.

    Objects exposing a ``pretty_name`` (or ``_pretty_name``) attribute are
    printed as that attribute; everything else goes through the inherited
    printer.
    """

    def _print(self, expr, **kwds):
        """Return ``expr.pretty_name`` / ``expr._pretty_name`` if present, else delegate."""
        if hasattr(expr, "pretty_name"):
            return expr.pretty_name
        elif hasattr(expr, "_pretty_name"):
            return expr._pretty_name
        return super()._print(expr, **kwds)

    def _print_Derivative(self, expr):
        """Return the second entry (pretty name) of ``print_Derivative``'s result."""
        return super().print_Derivative(expr)[1]

    def emptyPrinter(self, expr):
        """Reject expressions for which no ``_print_*`` method exists.

        Raises
        ------
        NotImplementedError
            Always; the message names this printer, the expression, its type
            and the MRO of that type.
        """
        # Only the fixed template goes through str.format: the printed
        # expression may itself contain '{' or '}', which would otherwise
        # break formatting and hide the intended NotImplementedError.
        msg = "\n{} does not implement _print_{}(self, expr).".format(
            self.__class__.__name__, expr.__class__.__name__
        )
        msg += f"\nExpression is {expr}."
        msg += "\nExpression type MRO is:"
        msg += "\n *" + "\n *".join(t.__name__ for t in type(expr).__mro__)
        raise NotImplementedError(msg)
class VarNamePrinter(BasePrinter, C99CodePrinter):
    """Printer returning a variable name for an expression.

    Objects exposing a ``var_name`` (or ``_var_name``) attribute are printed
    as that attribute; everything else goes through the inherited printer and
    is then stripped of spaces, with operators spelled out as words.
    """

    def _print(self, expr, **kwds):
        """Return ``expr.var_name`` / ``expr._var_name`` if present, else delegate.

        The delegated result has every space removed.
        """
        if hasattr(expr, "var_name"):
            return expr.var_name
        elif hasattr(expr, "_var_name"):
            return expr._var_name
        return super()._print(expr, **kwds).replace(" ", "")

    def _print_Derivative(self, expr):
        """Return the third entry (variable name) of ``print_Derivative``'s result."""
        return super().print_Derivative(expr)[2]

    def _print_Add(self, expr):
        """Print a sum, replacing ``+`` and ``-`` by ``plus`` / ``minus`` words."""
        s = super()._print_Add(expr)
        # Spaced (binary) operators first, then any remaining unary signs.
        s = s.replace(" + ", "_plus_").replace(" - ", "_minus_")
        s = s.replace("+", "plus_").replace("-", "minus_")
        return s

    def _print_Mul(self, expr):
        """Print a product, replacing ``*``, ``+`` and ``-`` by words."""
        s = super()._print_Mul(expr)
        s = s.replace(" * ", "_times_").replace("+", "plus_").replace("-", "minus_")
        return s

    def emptyPrinter(self, expr):
        """Reject expressions for which no ``_print_*`` method exists.

        Raises
        ------
        NotImplementedError
            Always; the message names this printer, the expression, its type
            and the MRO of that type.
        """
        # Only the fixed template goes through str.format: the printed
        # expression may itself contain '{' or '}', which would otherwise
        # break formatting and hide the intended NotImplementedError.
        msg = "\n{} does not implement _print_{}(self, expr).".format(
            self.__class__.__name__, expr.__class__.__name__
        )
        msg += f"\nExpression is {expr}."
        msg += "\nExpression type MRO is:"
        msg += "\n *" + "\n *".join(t.__name__ for t in type(expr).__mro__)
        raise NotImplementedError(msg)
class LatexNamePrinter(BasePrinter, LatexPrinter):
    """Printer returning the latex name of an expression.

    Objects exposing a ``latex_name`` (or ``_latex_name``) attribute are
    printed as that attribute; everything else goes through the inherited
    printer.
    """

    def _print(self, expr, **kwds):
        """Return ``expr.latex_name`` / ``expr._latex_name`` if present, else delegate."""
        if hasattr(expr, "latex_name"):
            return expr.latex_name
        elif hasattr(expr, "_latex_name"):
            return expr._latex_name
        return super()._print(expr, **kwds)

    def _print_Derivative(self, expr):
        """Return the fourth entry (latex name) of ``print_Derivative``'s result."""
        return super().print_Derivative(expr)[3]

    def _print_int(self, expr):
        """Print a Python ``int`` with ``str``."""
        return str(expr)

    def emptyPrinter(self, expr):
        """Reject expressions for which no ``_print_*`` method exists.

        Raises
        ------
        NotImplementedError
            Always; the message names this printer, the expression, its type
            and the MRO of that type.
        """
        # Only the fixed template goes through str.format: the printed
        # expression may itself contain '{' or '}', which would otherwise
        # break formatting and hide the intended NotImplementedError.
        msg = "\n{} does not implement _print_{}(self, expr).".format(
            self.__class__.__name__, expr.__class__.__name__
        )
        msg += f"\nExpression is {expr}."
        msg += "\nExpression type MRO is:"
        msg += "\n *" + "\n *".join(t.__name__ for t in type(expr).__mro__)
        raise NotImplementedError(msg)
# Module-level printer instances, one per name flavour
# (b = name, p = pretty name, l = latex name).
pbn = NamePrinter()
ppn = PrettyNamePrinter()
# NOTE(review): the variable-name printer instance is disabled here.
# pvn = VarNamePrinter()
pln = LatexNamePrinter()
def to_str(*args):
    """Convert the given values to a tuple of strings.

    When called with exactly one argument, that argument is first passed
    through ``to_tuple`` and the resulting items are converted; otherwise each
    positional argument is converted.
    """
    items = to_tuple(args[0]) if len(args) == 1 else args
    return tuple(map(str, items))
# exponents formatting functions
def bexp_fn(x):
    """Exponent suffix for plain names: ``^x`` when ``x > 1``, else empty."""
    if x > 1:
        return f"^{x}"
    return ""
def pexp_fn(x, sep=","):
    """Exponent suffix for pretty names.

    Returns ``exponents(x, sep=sep)`` when ``x > 1`` and an empty string
    otherwise.  (Presumably ``exponents`` renders superscript characters;
    see ``hysop.tools.sympy_utils`` to confirm.)
    """
    return exponents(x, sep=sep) if (x > 1) else ""
def vexp_fn(x):
    """Exponent suffix for variable names: ``e<x>`` when ``x > 1``, else empty."""
    if x > 1:
        return f"e{x}"
    return ""
def lexp_fn(x):
    """Exponent suffix for latex names, or empty unless ``x > 1``.

    Braces are emitted as ``<LBRACKET>`` / ``<RBRACKET>`` placeholders.
    """
    if x > 1:
        return f"^<LBRACKET>{x}<RBRACKET>"
    return ""
# powers formatting functions
def bpow_fn(x):
    """Power suffix for plain names: ``**x`` when ``x > 1``, else empty."""
    if x > 1:
        return f"**{x}"
    return ""
def ppow_fn(x, sep=","):
    """Power suffix for pretty names.

    Returns ``exponents(x, sep=sep)`` when ``x > 1`` and an empty string
    otherwise.
    """
    return exponents(x, sep=sep) if (x > 1) else ""
def vpow_fn(x):
    """Power suffix for variable names: ``p<x>`` when ``x > 1``, else empty."""
    if x > 1:
        return f"p{x}"
    return ""
def lpow_fn(x):
    """Power suffix for latex names, or empty unless ``x > 1``.

    Braces are emitted as ``<LBRACKET>`` / ``<RBRACKET>`` placeholders.
    """
    if x > 1:
        return f"^<LBRACKET>{x}<RBRACKET>"
    return ""
# subscripts formatting functions
def bsub_fn(x):
    """Subscript suffix for plain names: ``_x``, or empty when ``x`` is None."""
    if x is None:
        return ""
    return f"_{x}"
def psub_fn(x, sep=","):
    """Subscript suffix for pretty names.

    Returns ``subscripts(x, sep=sep)``, or an empty string when ``x`` is None.
    """
    return subscripts(x, sep=sep) if (x is not None) else ""
def vsub_fn(x):
    """Subscript suffix for variable names: ``s<x>``, or empty when ``x`` is None."""
    if x is None:
        return ""
    return f"s{x}"
def lsub_fn(x):
    """Subscript suffix for latex names, or empty when ``x`` is None.

    Braces are emitted as ``<LBRACKET>`` / ``<RBRACKET>`` placeholders.
    """
    if x is None:
        return ""
    return f"_<LBRACKET>{x}<RBRACKET>"
# components formatting functions
def bcomp_fn(x):
    """Component suffix for plain names.

    Joins the ``to_str`` conversion of ``x`` with commas; returns an empty
    string when ``x`` is None.
    """
    if x is None:
        return ""
    return ",".join(to_str(x))
def pcomp_fn(x, sep=","):
    """Component suffix for pretty names.

    Returns ``subscripts(x, sep=sep)``, or an empty string when ``x`` is None.
    """
    return subscripts(x, sep=sep) if (x is not None) else ""
def vcomp_fn(x):
    """Component suffix for variable names.

    Returns ``_`` followed by the ``to_str`` conversion of ``x`` joined with
    underscores; returns an empty string when ``x`` is None.
    """
    if x is None:
        return ""
    joined = "_".join(to_str(x))
    return "_" + joined
def lcomp_fn(x):
    """Component suffix for latex names.

    Returns a subscript group holding the comma-joined ``to_str`` conversion
    of ``x``, with braces emitted as ``<LBRACKET>`` / ``<RBRACKET>``
    placeholders; returns an empty string when ``x`` is None.
    """
    if x is None:
        return ""
    joined = ",".join(to_str(x))
    return "_<LBRACKET>{}<RBRACKET>".format(joined)
# join formatting functions
def bjoin_fn(x):
    """Join plain names with underscores (empty string when ``x`` is None)."""
    if x is None:
        return ""
    return "_".join(to_str(x))
def pjoin_fn(x):
    """Concatenate pretty names without separator (empty string when ``x`` is None)."""
    if x is None:
        return ""
    return "".join(to_str(x))
def vjoin_fn(x):
    """Join variable names with underscores (empty string when ``x`` is None)."""
    if x is None:
        return ""
    return "_".join(to_str(x))
def ljoin_fn(x):
    """Concatenate latex names without separator (empty string when ``x`` is None)."""
    if x is None:
        return ""
    return "".join(to_str(x))
# divide formatting functions
def bdivide_fn(x, y):
    """Plain-name quotient: ``x/y``."""
    return "{}/{}".format(x, y)
def pdivide_fn(x, y):
    """Pretty-name quotient: both operands converted with ``to_str``, joined by ``/``."""
    numerator, denominator = to_str(x, y)
    return f"{numerator}/{denominator}"
def vdivide_fn(x, y):
    """Variable-name quotient: operands separated by a double underscore."""
    return "{}__{}".format(x, y)
def ldivide_fn(x, y):
    """Latex quotient as a ``\\dfrac`` with two groups.

    Braces are emitted as ``<LBRACKET>`` / ``<RBRACKET>`` placeholders.
    """
    numerator = f"<LBRACKET>{x}<RBRACKET>"
    denominator = f"<LBRACKET>{y}<RBRACKET>"
    return "\\dfrac" + numerator + denominator
class DifferentialStringFormatter:
    """
    Utility class to format differential related strings like partial derivatives.

    All string formatting functions return 4 different results:
        * A string that can be used as identifier (name).
        * A pretty string in utf-8 (pretty_name).
        * A variable name that can be used as a valid C identifier
          for code generation (var_name).
        * A latex string that can be compiled and displayed with latex
          (latex_name).

    Prefix used for methods:
        b = name
        p = pretty_name
        v = var_name
        l = latex_name

    Latex braces are carried around as ``<LBRACKET>`` / ``<RBRACKET>``
    placeholders and only turned into ``{`` / ``}`` by
    ``format_special_characters`` (``fsc=True``).

    See __main__ at the bottom of this file for usage.
    """

    exp_fns = (bexp_fn, pexp_fn, vexp_fn, lexp_fn)
    pow_fns = (bpow_fn, ppow_fn, vpow_fn, lpow_fn)
    sub_fns = (bsub_fn, psub_fn, vsub_fn, lsub_fn)
    comp_fns = (bcomp_fn, pcomp_fn, vcomp_fn, lcomp_fn)
    join_fns = (bjoin_fn, pjoin_fn, vjoin_fn, ljoin_fn)
    divide_fns = (bdivide_fn, pdivide_fn, vdivide_fn, ldivide_fn)

    @staticmethod
    def format_special_characters(ss):
        """Replace the ``<LBRACKET>`` / ``<RBRACKET>`` placeholders in *ss* by braces."""
        special_characters = {
            "<LBRACKET>": "{",
            "<RBRACKET>": "}",
        }
        for k, v in special_characters.items():
            ss = ss.replace(k, v)
        return ss

    @classmethod
    def return_names(cls, *args, **kwds):
        """Return the given names, optionally formatting special characters.

        Parameters
        ----------
        *args : str
            One or more names.
        fsc : bool, keyword only, default True
            "format special characters": when true every name goes through
            ``format_special_characters``.

        Returns
        -------
        A single string when exactly one name is given, else a tuple.
        """
        # fsc = format special characters
        fsc = kwds.get("fsc", True)
        assert len(args) >= 1
        if fsc:
            args = tuple(cls.format_special_characters(a) for a in args)
        # The single-name case used to be inverted (raw string for fsc=True,
        # None for fsc=False); it now follows the multi-name behaviour.
        if len(args) == 1:
            return args[0]
        return args

    @classmethod
    def format_partial_name(
        cls,
        bvar,
        pvar,
        vvar,
        lvar,
        bpow_fn=bpow_fn,
        ppow_fn=ppow_fn,
        vpow_fn=vpow_fn,
        lpow_fn=lpow_fn,
        bcomp_fn=bcomp_fn,
        pcomp_fn=pcomp_fn,
        vcomp_fn=vcomp_fn,
        lcomp_fn=lcomp_fn,
        blp="(",
        plp="",
        vlp="",
        llp="",
        brp=")",
        prp="",
        vrp="",
        lrp="",
        bd="d",
        pd=partial,
        vd="d",
        ld=r"<LBRACKET>\partial<RBRACKET>",
        dpow=1,
        varpow=1,
        components=None,
        trigp=3,
        fsc=True,
    ):
        """Format one differential term in the four name flavours.

        Each name is built from the template
        ``{d}{dpow}{lp}{var}{components}{rp}{varpow}``.

        Parameters
        ----------
        bvar, pvar, vvar, lvar : str
            Variable name in each flavour.
        *pow_fn, *comp_fn : callable
            Formatters for powers and components in each flavour.
        *lp, *rp : str
            Left / right delimiters put around the variable; they are dropped
            when the variable name has at most *trigp* characters.
        bd, pd, vd, ld : str
            Differential symbol in each flavour; dropped when ``dpow == 0``.
        dpow : int
            Power applied to the differential symbol.
        varpow : int
            Power applied to the variable; must not be 0.
        components : optional
            Passed to the component formatters.
        trigp : int
            Name length above which the delimiters are kept.
        fsc : bool
            Forwarded to ``return_names``.

        Returns
        -------
        tuple of 4 str: (name, pretty_name, var_name, latex_name).
        """
        assert varpow != 0
        bd = "" if (dpow == 0) else bd
        pd = "" if (dpow == 0) else pd
        vd = "" if (dpow == 0) else vd
        ld = "" if (dpow == 0) else ld

        blp = "" if len(bvar) <= trigp else blp
        brp = "" if len(bvar) <= trigp else brp
        plp = "" if len(pvar) <= trigp else plp
        prp = "" if len(pvar) <= trigp else prp
        vlp = "" if len(vvar) <= trigp else vlp
        vrp = "" if len(vvar) <= trigp else vrp
        llp = "" if len(lvar) <= trigp else llp
        lrp = "" if len(lvar) <= trigp else lrp

        template = "{d}{dpow}{lp}{var}{components}{rp}{varpow}"
        bname = template.format(
            d=bd,
            dpow=bpow_fn(dpow),
            components=bcomp_fn(components),
            var=bvar,
            varpow=bpow_fn(varpow),
            lp=blp,
            rp=brp,
        )
        pname = template.format(
            d=pd,
            dpow=ppow_fn(dpow),
            components=pcomp_fn(components),
            var=pvar,
            varpow=ppow_fn(varpow),
            lp=plp,
            rp=prp,
        )
        vname = template.format(
            d=vd,
            dpow=vpow_fn(dpow),
            components=vcomp_fn(components),
            var=vvar,
            varpow=vpow_fn(varpow),
            lp=vlp,
            rp=vrp,
        )
        lname = template.format(
            d=ld,
            dpow=lpow_fn(dpow),
            components=lcomp_fn(components),
            var=lvar,
            varpow=lpow_fn(varpow),
            lp=llp,
            rp=lrp,
        )
        return cls.return_names(bname, pname, vname, lname, fsc=fsc)

    @classmethod
    def format_partial_names(
        cls,
        bvars,
        pvars,
        vvars,
        lvars,
        varpows,
        bjoin_fn=bjoin_fn,
        pjoin_fn=pjoin_fn,
        vjoin_fn=vjoin_fn,
        ljoin_fn=ljoin_fn,
        components=None,
        fsc=True,
        **kwds,
    ):
        """Format several differential terms and join them per flavour.

        Parameters
        ----------
        bvars, pvars, vvars, lvars : str or tuple of str
            Variable names in each flavour (same length).
        varpows : int or tuple of int
            Power of each variable; variables with power 0 are skipped and at
            least one power must be positive.
        *join_fn : callable
            Functions joining the per-variable names of each flavour.
        components : optional
            One component specification per variable.
        fsc : bool
            Forwarded to ``return_names``.
        **kwds
            Forwarded to ``format_partial_name``.

        Returns
        -------
        tuple of 4 str: (name, pretty_name, var_name, latex_name).
        """
        bvars, pvars, vvars, lvars = (
            to_tuple(bvars),
            to_tuple(pvars),
            to_tuple(vvars),
            to_tuple(lvars),
        )
        varpows = to_tuple(varpows)
        assert len(bvars) == len(pvars) == len(vvars) == len(lvars) == len(varpows)
        assert any(v > 0 for v in varpows)
        nvars = len(bvars)

        if components is not None:
            components = to_tuple(components)
            assert len(components) == nvars
        else:
            components = (None,) * nvars

        bnames, pnames, vnames, lnames = (), (), (), ()
        for bvar, pvar, vvar, lvar, varpow, component in zip(
            bvars, pvars, vvars, lvars, varpows, components
        ):
            if varpow == 0:
                continue
            # Placeholders are resolved once, on the joined result below.
            res = cls.format_partial_name(
                bvar=bvar,
                pvar=pvar,
                vvar=vvar,
                lvar=lvar,
                varpow=varpow,
                components=component,
                fsc=False,
                **kwds,
            )
            assert len(res) == 4
            bnames += (res[0],)
            pnames += (res[1],)
            vnames += (res[2],)
            lnames += (res[3],)

        return cls.return_names(
            bjoin_fn(bnames),
            pjoin_fn(pnames),
            vjoin_fn(vnames),
            ljoin_fn(lnames),
            fsc=fsc,
        )

    @classmethod
    def format_pd(
        cls,
        bvar,
        pvar,
        vvar,
        lvar,
        bxvars="x",
        pxvars=xsymbol,
        vxvars="x",
        lxvars="x",
        varpows=1,
        var_components=None,
        xvars_components=None,
        bdivide_fn=bdivide_fn,
        pdivide_fn=pdivide_fn,
        vdivide_fn=vdivide_fn,
        ldivide_fn=ldivide_fn,
        fsc=True,
        **kwds,
    ):
        """Format a partial derivative as numerator over denominator.

        The numerator is ``format_partial_name`` of the variable with
        ``dpow = sum(varpows)``; the denominator is ``format_partial_names``
        of the differentiation variables.

        Parameters
        ----------
        bvar, pvar, vvar, lvar : str
            Differentiated variable in each flavour.
        bxvars, pxvars, vxvars, lxvars : str or tuple of str
            Differentiation variables in each flavour (same length).
        varpows : int or tuple of int
            Order of derivation for each differentiation variable.
        var_components, xvars_components : optional
            Components of the variable / of the differentiation variables.
        *divide_fn : callable
            Functions combining numerator and denominator in each flavour.
        fsc : bool
            Forwarded to ``return_names``.
        **kwds
            Forwarded to both formatting calls; reserved keywords are rejected.

        Returns
        -------
        tuple of 4 str: (name, pretty_name, var_name, latex_name).
        """
        for k in ("dpow", "components", "bvars", "pvars", "vvars", "lvars", "varpow"):
            assert k not in kwds, f"Cannot specify reserved keyword {k}."

        bxvars, pxvars, vxvars, lxvars = (
            to_tuple(bxvars),
            to_tuple(pxvars),
            to_tuple(vxvars),
            to_tuple(lxvars),
        )
        varpows = to_tuple(varpows)
        assert len(bxvars) == len(pxvars) == len(vxvars) == len(lxvars) == len(varpows)
        assert any(v > 0 for v in varpows)
        dpow = sum(varpows)

        numerator = cls.format_partial_name(
            bvar=bvar,
            pvar=pvar,
            vvar=vvar,
            lvar=lvar,
            fsc=False,
            dpow=dpow,
            components=var_components,
            **kwds,
        )

        denominator = cls.format_partial_names(
            bvars=bxvars,
            pvars=pxvars,
            vvars=vxvars,
            lvars=lxvars,
            fsc=False,
            varpows=varpows,
            components=xvars_components,
            **kwds,
        )

        return cls.return_names(
            bdivide_fn(numerator[0], denominator[0]),
            pdivide_fn(numerator[1], denominator[1]),
            vdivide_fn(numerator[2], denominator[2]),
            ldivide_fn(numerator[3], denominator[3]),
            fsc=fsc,
        )
# Demo: prints the four name flavours produced by DifferentialStringFormatter
# for a series of inputs, one section per formatting method.
if __name__ == "__main__":

    def _print(*args, **kwds):
        """Print names comma-separated, or one per line with ``multiline=True``.

        A single tuple argument is unpacked into its items first.
        """
        if isinstance(args[0], tuple):
            assert len(args) == 1
            args = args[0]
        if ("multiline" in kwds) and (kwds["multiline"] is True):
            for a in args:
                print(a)
        else:
            print(", ".join(a for a in args))

    # The section separators below were bare `print` expressions, which do
    # nothing in Python 3; print() emits the intended blank line.
    print()
    bvar, pvar, vvar, lvar = (
        "Fext",
        "Fₑₓₜ",
        "Fext",
        "<LBRACKET>F_<LBRACKET>ext<RBRACKET><RBRACKET>",
    )
    _print(DifferentialStringFormatter.return_names(bvar, pvar, vvar, lvar))

    # format_partial_name: varying the power of the differential symbol.
    print()
    _print(
        DifferentialStringFormatter.format_partial_name(bvar, pvar, vvar, lvar, dpow=0)
    )
    _print(
        DifferentialStringFormatter.format_partial_name(bvar, pvar, vvar, lvar, dpow=1)
    )
    _print(
        DifferentialStringFormatter.format_partial_name(bvar, pvar, vvar, lvar, dpow=2)
    )
    _print(
        DifferentialStringFormatter.format_partial_name(
            bvar, pvar, vvar, lvar, dpow=3, components=0
        )
    )
    _print(
        DifferentialStringFormatter.format_partial_name(
            bvar, pvar, vvar, lvar, dpow=4, components=(0, 2)
        )
    )

    # format_partial_name: varying the power of the variable.
    print()
    bvar, pvar, vvar, lvar = ("x",) * 4
    _print(
        DifferentialStringFormatter.format_partial_name(
            bvar, pvar, vvar, lvar, varpow=1
        )
    )
    _print(
        DifferentialStringFormatter.format_partial_name(
            bvar, pvar, vvar, lvar, varpow=2
        )
    )
    _print(
        DifferentialStringFormatter.format_partial_name(
            bvar, pvar, vvar, lvar, varpow=3, components=0
        )
    )
    _print(
        DifferentialStringFormatter.format_partial_name(
            bvar, pvar, vvar, lvar, varpow=4, components=(0, 2)
        )
    )

    # format_partial_names: several variables.
    print()
    bvar, pvar, vvar, lvar = (("x", "y"),) * 4
    try:
        # All powers zero must be rejected by an assertion.
        _print(
            DifferentialStringFormatter.format_partial_names(
                bvar, pvar, vvar, lvar, varpows=(0, 0)
            )
        )
        raise RuntimeError()
    except AssertionError:
        pass
    _print(
        DifferentialStringFormatter.format_partial_names(
            bvar, pvar, vvar, lvar, varpows=(0, 1)
        )
    )
    _print(
        DifferentialStringFormatter.format_partial_names(
            bvar, pvar, vvar, lvar, varpows=(1, 0)
        )
    )
    _print(
        DifferentialStringFormatter.format_partial_names(
            bvar, pvar, vvar, lvar, varpows=(1, 1)
        )
    )
    _print(
        DifferentialStringFormatter.format_partial_names(
            bvar, pvar, vvar, lvar, varpows=(1, 2)
        )
    )
    _print(
        DifferentialStringFormatter.format_partial_names(
            bvar, pvar, vvar, lvar, varpows=(2, 2)
        )
    )
    _print(
        DifferentialStringFormatter.format_partial_names(
            bvar, pvar, vvar, lvar, varpows=(2, 2), components=(0, 1)
        )
    )
    _print(
        DifferentialStringFormatter.format_partial_names(
            bvar, pvar, vvar, lvar, varpows=(2, 2), components=((0, 1), (1, 0))
        )
    )

    # format_pd: full partial derivatives.
    print()
    bvar, pvar, vvar, lvar = (
        "Fext",
        "Fₑₓₜ",
        "Fext",
        "<LBRACKET>F_<LBRACKET>ext<RBRACKET><RBRACKET>",
    )
    bxvars, pxvars, vxvars, lxvars = (("x", "y"),) * 4
    _print(DifferentialStringFormatter.format_pd(bvar, pvar, vvar, lvar))
    _print(DifferentialStringFormatter.format_pd(bvar, pvar, vvar, lvar, varpows=2))
    _print(
        DifferentialStringFormatter.format_pd(
            bvar, pvar, vvar, lvar, bxvars, pxvars, vxvars, lxvars, varpows=(1, 0)
        )
    )
    _print(
        DifferentialStringFormatter.format_pd(
            bvar, pvar, vvar, lvar, bxvars, pxvars, vxvars, lxvars, varpows=(0, 1)
        )
    )
    _print(
        DifferentialStringFormatter.format_pd(
            bvar, pvar, vvar, lvar, bxvars, pxvars, vxvars, lxvars, varpows=(1, 1)
        )
    )
    _print(
        DifferentialStringFormatter.format_pd(
            bvar, pvar, vvar, lvar, bxvars, pxvars, vxvars, lxvars, varpows=(5, 2)
        )
    )

    # format_pd with components on both the variable and the x variables.
    print()
    bxvars, pxvars, vxvars, lxvars = (("x",) * 5,) * 4
    varpows = (1,) * 5
    xvars_components = tuple(range(5))
    var_components = (0, 4, 3, 2)
    _print(
        DifferentialStringFormatter.format_pd(
            bvar,
            pvar,
            vvar,
            lvar,
            bxvars,
            pxvars,
            vxvars,
            lxvars,
            varpows=varpows,
            xvars_components=xvars_components,
            var_components=var_components,
        ),
        multiline=True,
    )